#!/usr/bin/env python3
"""
Phase 6b: Joint Permutation Test for Income Structure of the Coefficient Surface

The per-step permutation test (phase 6, test 5) checks each step independently.
This script computes joint test statistics that capture the *consistency* of the
income-ordered path being elevated relative to random orderings.

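Six joint statistics:
  1. Mean Z₁ across all steps (area under the coefficient path)
  2. Fraction of steps where income-ordered Z₁ > random median
  3. Maximum Z₁ at any step (peak of the path)
  4. Weighted mean Z₁, with weights 1, 1/2, 1/3, ... emphasizing early steps
  5. Fisher-style combination of the per-step percentile ranks
  6. Slope of Z₁ on the step index (is the income-ordered decline steeper?)

For statistics 1-4, the joint p-value is the fraction of the 1000 random
permutations whose statistic equals or exceeds the income-ordered value; for
statistic 6 it is the fraction whose slope is at or below the income-ordered
slope (a one-sided test for a steeper decline).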
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats as sp_stats

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS

# ── Paths and run configuration ──────────────────────────────────────────
# The input panel is read from FOLLOWUP_DIR; tables and figures are written
# to the fragility output directories.
FOLLOWUP_DIR = PROJECT_DIR / "multilateral" / "followup"
TABLE_DIR = PROJECT_DIR / "fragility" / "output" / "tables"
FIG_DIR = PROJECT_DIR / "fragility" / "output" / "figures"

N_PERMUTATIONS = 1000  # number of random country orderings forming the null
TARGET_DV = 'ca_gdp'  # dependent variable (CA/GDP)
# Regressors passed to PanelGLS; Z_1 is the coefficient tracked along each path.
BASELINE_VARS = ['Z_1', 'Z_2', 'Z_3', 'fiscal_bal_gdp', 'kaopen',
                 'nfa_gdp_lag', 'log_rel_opw']
# Seed NumPy's global RNG so the np.random.shuffle orderings are reproducible.
np.random.seed(42)

# ── Load and prepare panel ────────────────────────────────────────────────
print("=" * 70)
print("PHASE 6b: JOINT PERMUTATION TEST")
print("=" * 70)

panel = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv")
panel = panel[(panel['year'] >= 1986) & (panel['year'] <= 2024)].copy()

# Income ranking: mean GDP per capita (PPP) per country over the sample window.
mean_gdp = panel.groupby('iso3')['gdp_pc_ppp'].mean().dropna()
country_income_rank = mean_gdp.sort_values(ascending=False)  # richest first

# Estimation sample: rows with the DV and every available regressor present.
avail_vars = [v for v in BASELINE_VARS if v in panel.columns]
ca_panel = panel.dropna(subset=[TARGET_DV] + avail_vars).copy()
ca_countries_all = list(ca_panel['iso3'].unique())
# Build the membership set once; calling .unique() inside the comprehension
# recomputed it (and did a linear scan) for every ranked country.
ca_country_set = set(ca_countries_all)
ca_countries_income_ranked = [c for c in country_income_rank.index
                              if c in ca_country_set]

print(f"Panel: {len(ca_countries_all)} countries with CA data")


def fit_z1_fast(data, dep_var, indep_vars):
    """Estimate the panel GLS model on *data* and return only the Z₁ coefficient.

    Parameters
    ----------
    data : pandas.DataFrame
        Panel with 'iso3' and 'year' columns plus the model variables.
    dep_var : str
        Name of the dependent-variable column.
    indep_vars : list of str
        Candidate regressors; names missing from *data* are silently skipped.
        'Z_1' must be among the retained regressors.

    Returns
    -------
    float
        The Z_1 coefficient, or NaN when the complete-case sample has fewer
        than 5 countries or fewer than 30 rows.
    """
    regressors = [name for name in indep_vars if name in data.columns]
    complete = data.dropna(subset=[dep_var, *regressors])

    too_few_countries = complete['iso3'].nunique() < 5
    too_few_rows = len(complete) < 30
    if too_few_countries or too_few_rows:
        return np.nan

    model = PanelGLS()
    model.fit(
        complete[dep_var].values,
        complete[regressors].values,
        complete['iso3'].values,
        complete['year'].values,
    )
    # NOTE(review): assumes PanelGLS.beta is ordered like the regressor
    # columns (no intercept prepended) — confirm against model.PanelGLS.
    return model.beta[regressors.index('Z_1')]


# ── Define steps ──────────────────────────────────────────────────────────
# Each step is a sample size (number of countries) at which Z₁ is
# re-estimated: every 20 countries, plus the full CA sample as a final step.
n_ca_countries = len(ca_countries_all)
steps = list(range(20, n_ca_countries + 1, 20))
# `not steps` covers panels with fewer than 20 countries: the range is then
# empty and indexing steps[-1] would raise IndexError.
if not steps or steps[-1] != n_ca_countries:
    steps.append(n_ca_countries)
# Cap at countries we have income data for
steps = [s for s in steps if s <= len(ca_countries_income_ranked)]
if not steps:
    sys.exit("No steps to evaluate: too few countries with both CA data "
             "and an income rank.")

print(f"Steps: {steps}")
print(f"Countries with income rank: {len(ca_countries_income_ranked)}")

# ── Income-ordered path ──────────────────────────────────────────────────
# At each step, keep the `n_in` richest countries and re-estimate Z₁.
print("\nComputing income-ordered path...")
income_path = {}
for n_in in steps:
    richest = ca_countries_income_ranked[:n_in]
    subset = ca_panel[ca_panel['iso3'].isin(set(richest))]
    income_path[n_in] = fit_z1_fast(subset, TARGET_DV, BASELINE_VARS)
    print(f"  Step {n_in:3d}: Z₁ = {income_path[n_in]:.2f}")

# Z₁ values aligned with `steps` (NaN where the fit was skipped).
income_z1s = np.array([income_path[n_in] for n_in in steps])


# ── Random permutation paths (storing full paths for joint tests) ─────────
print(f"\nRunning {N_PERMUTATIONS} random permutations...")
# Null orderings shuffle the same pool of countries that the income ordering
# ranks. Shuffling ca_countries_all instead would let random paths include
# countries that have CA data but no income rank, so the null would draw
# from a different population than the income-ordered path it is compared to.
perm_pool = list(ca_countries_income_ranked)
# perm_paths[i, j] = Z₁ for random ordering i at steps[j] (NaN if fit skipped).
perm_paths = np.full((N_PERMUTATIONS, len(steps)), np.nan)

for perm in range(N_PERMUTATIONS):
    shuffled = perm_pool.copy()
    np.random.shuffle(shuffled)

    # steps are capped at len(ca_countries_income_ranked), so every prefix
    # taken below has exactly `step` countries.
    for j, step in enumerate(steps):
        countries_in = set(shuffled[:step])
        sub = ca_panel[ca_panel['iso3'].isin(countries_in)]
        perm_paths[perm, j] = fit_z1_fast(sub, TARGET_DV, BASELINE_VARS)

    if (perm + 1) % 100 == 0:
        print(f"  Permutation {perm + 1}/{N_PERMUTATIONS}")


# ══════════════════════════════════════════════════════════════════════════
# JOINT TEST STATISTICS
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("JOINT TEST RESULTS")
print("=" * 70)

# Keep a step only if the income-ordered fit succeeded and at most half of
# the permutation fits at that step returned NaN.
n_perm_failures = np.isnan(perm_paths).sum(axis=0)
valid_mask = ~np.isnan(income_z1s) & (n_perm_failures <= N_PERMUTATIONS * 0.5)

if not valid_mask.any():
    # Every statistic below reduces over the valid steps; fail with a clear
    # message instead of a zero-size-array error further down.
    sys.exit("No valid steps: the income-ordered fit or most permutation "
             "fits failed at every step.")

valid_steps = [s for s, keep in zip(steps, valid_mask) if keep]
valid_income = income_z1s[valid_mask]
valid_perm = perm_paths[:, valid_mask]  # boolean indexing -> independent copy

print(f"Valid steps: {valid_steps}")

# Fill remaining NaNs in permutations with their column (step) median for
# robustness, so every random path has a value at every valid step.
step_nan_medians = np.nanmedian(valid_perm, axis=0)
nan_rows, nan_cols = np.nonzero(np.isnan(valid_perm))
valid_perm[nan_rows, nan_cols] = step_nan_medians[nan_cols]

# ── Statistic 1: Mean Z₁ across steps (area under curve) ─────────────
# Score each path by its average Z₁ over the valid steps (proportional to the
# area under the path). The p-value is the share of random paths whose
# average is at least the income-ordered average.
income_mean = valid_income.mean()
perm_means = valid_perm.mean(axis=1)
p_mean = (perm_means >= income_mean).mean()

for report_line in (
    "\n1. MEAN Z₁ ACROSS PATH (area under curve):",
    f"   Income-ordered: {income_mean:.2f}",
    f"   Random median:  {np.median(perm_means):.2f}",
    f"   Random 95th:    {np.percentile(perm_means, 95):.2f}",
    f"   Joint p-value:  {p_mean:.4f}",
):
    print(report_line)

# ── Statistic 2: Fraction of steps above random median ────────────────
# The per-step median of the random draws is the reference line; each path
# is scored by the share of steps where it lies strictly above that line.
step_medians = np.median(valid_perm, axis=0)
income_above_mask = valid_income > step_medians
n_income_above = int(np.count_nonzero(income_above_mask))
income_above = np.mean(income_above_mask)

# Row-wise share for every random path at once (row i = permutation i).
perm_above = np.mean(valid_perm > step_medians, axis=1)
p_above = np.mean(perm_above >= income_above)

print(f"\n2. FRACTION OF STEPS ABOVE RANDOM MEDIAN:")
# Report the exact count: int(income_above * n) can truncate k/n*n to k-1
# through float rounding (e.g. 1/49 * 49 evaluates just below 1).
print(f"   Income-ordered: {income_above:.0%} ({n_income_above}/{len(valid_steps)} steps)")
print(f"   Random median:  {np.median(perm_above):.0%}")
print(f"   Joint p-value:  {p_above:.4f}")

# ── Statistic 3: Maximum Z₁ at any step ──────────────────────────────
# Peak of each path; the p-value is the share of random peaks that reach or
# exceed the income-ordered peak.
income_max = valid_income.max()
perm_maxes = valid_perm.max(axis=1)
p_max = (perm_maxes >= income_max).mean()

stat3_report = [
    "\n3. MAXIMUM Z₁ ALONG PATH:",
    f"   Income-ordered: {income_max:.2f}",
    f"   Random median:  {np.median(perm_maxes):.2f}",
    f"   Random 95th:    {np.percentile(perm_maxes, 95):.2f}",
    f"   Joint p-value:  {p_max:.4f}",
]
print("\n".join(stat3_report))

# ── Statistic 4: Weighted mean (weight early steps more) ─────────────
# Early steps (small, rich samples) are where the surface diverges most.
# Harmonic weights 1, 1/2, 1/3, ... over the valid steps, normalized to sum
# to one.
step_numbers = np.arange(1, len(valid_steps) + 1)
weights = 1.0 / step_numbers
weights /= weights.sum()

income_wmean = np.average(valid_income, weights=weights)
perm_wmeans = np.array([np.average(path, weights=weights) for path in valid_perm])
p_wmean = np.mean(perm_wmeans >= income_wmean)

print("\n4. WEIGHTED MEAN (early-step emphasis):")
print(f"   Income-ordered: {income_wmean:.2f}")
print(f"   Random median:  {np.median(perm_wmeans):.2f}")
print(f"   Random 95th:    {np.percentile(perm_wmeans, 95):.2f}")
print(f"   Joint p-value:  {p_wmean:.4f}")

# ── Statistic 5: Fisher combined test from per-step percentile ranks ──
# rank_j = share of random draws at step j that are <= the income-ordered Z₁.
# Fisher statistic: -2 * sum(log(1 - rank_j)); 1 - rank_j is the upper-tail
# share of random draws strictly above the income-ordered value.
#
# The textbook chi-squared(2k) reference distribution assumes the k per-step
# ranks are independent. They are not: each step's sample is a superset of
# the previous step's (nested prefixes of one ordering), so Z₁ estimates are
# strongly correlated across steps. A step where every ordering selects the
# same countries (e.g. the full sample, when it is a valid step) also gets
# rank 1 and a large clipped contribution. Both effects make the chi-squared
# p-value anti-conservative, so the reported p-value is instead calibrated
# against the permutation distribution: each random path's ranks are taken
# within the same null draws and turned into a Fisher statistic.
per_step_ranks = np.array([
    np.mean(valid_perm[:, j] <= valid_income[j])
    for j in range(len(valid_steps))
])

# Clip to avoid log(0)
ranks_clipped = np.clip(per_step_ranks, 0.001, 0.999)
fisher_stat = -2 * np.sum(np.log(1 - ranks_clipped))
fisher_df = 2 * len(valid_steps)
# Nominal p-value, kept for reference only (valid only for independent steps).
fisher_p_chi2 = sp_stats.chi2.sf(fisher_stat, fisher_df)

# perm_ranks[i, j] = share of random draws at step j that are <= draw i,
# i.e. the same rank definition as above, applied to each random path.
n_perm_draws = valid_perm.shape[0]
sorted_perm = np.sort(valid_perm, axis=0)
perm_ranks = np.column_stack([
    np.searchsorted(sorted_perm[:, j], valid_perm[:, j], side='right')
    for j in range(valid_perm.shape[1])
]) / n_perm_draws
perm_fisher = -2 * np.sum(np.log(1 - np.clip(perm_ranks, 0.001, 0.999)), axis=1)
fisher_p = np.mean(perm_fisher >= fisher_stat)

print(f"\n5. FISHER COMBINED TEST (from per-step ranks):")
print(f"   Per-step ranks: {', '.join(f'{r:.2f}' for r in per_step_ranks)}")
print(f"   Mean rank:      {np.mean(per_step_ranks):.3f} (0.50 under H0)")
print(f"   Fisher stat:    {fisher_stat:.2f} ({len(valid_steps)} steps)")
print(f"   Fisher p-value: {fisher_p:.4f} (permutation-calibrated)")
print(f"   Nominal chi2 p: {fisher_p_chi2:.4f} (chi2 with {fisher_df} df; assumes independent steps)")

# ── Statistic 6: Monotone decline test ────────────────────────────────
# The surface predicts that adding poorer countries should reduce Z₁.
# Each path is summarized by the least-squares slope of Z₁ on the step index
# (0, 1, 2, ... — per step, not per country). One-sided test: the p-value is
# the share of random slopes at or below the income-ordered slope.
from numpy.polynomial.polynomial import polyfit

step_index = np.arange(len(valid_steps))


def _path_slope(path):
    """Return the fitted linear slope of *path* against ``step_index``."""
    # numpy.polynomial's polyfit returns coefficients lowest degree first.
    _intercept, slope = polyfit(step_index, path, 1)
    return slope


income_slope = _path_slope(valid_income)
perm_slopes = np.array([_path_slope(path) for path in valid_perm])
p_slope = np.mean(perm_slopes <= income_slope)  # one-sided: income should be MORE negative

print("\n6. DECLINING SLOPE TEST:")
print(f"   Income-ordered slope: {income_slope:.3f} per step")
print(f"   Random median slope:  {np.median(perm_slopes):.3f}")
print(f"   Joint p-value (one-sided, slope ≤ income): {p_slope:.4f}")


# ══════════════════════════════════════════════════════════════════════════
# COMBINED SUMMARY
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("COMBINED SUMMARY")
print("=" * 70)

# (name, income-ordered value, p-value) for every joint statistic above.
results = [
    ('Mean Z₁ (AUC)', income_mean, p_mean),
    ('Fraction above median', income_above, p_above),
    ('Maximum Z₁', income_max, p_max),
    ('Weighted mean (early emphasis)', income_wmean, p_wmean),
    ('Fisher combined', fisher_stat, fisher_p),
    ('Declining slope', income_slope, p_slope),
]

print(f"\n{'Statistic':<35s} {'Value':>10s} {'p-value':>10s}")
print("-" * 57)
for name, val, p in results:
    # Conventional significance markers; † flags p < 0.10.
    sig = '***' if p < 0.001 else '**' if p < 0.01 else '*' if p < 0.05 \
        else '†' if p < 0.10 else ''
    print(f"  {name:<33s} {val:10.3f} {p:10.4f} {sig}")

# Any significant at joint level? The denominator follows the number of
# statistics actually reported rather than a hard-coded count. Note these
# statistics are computed from the same paths, so the counts are not
# independent tests.
n_tests = len(results)
n_sig_05 = sum(1 for _, _, p in results if p < 0.05)
n_sig_10 = sum(1 for _, _, p in results if p < 0.10)
print(f"\n  {n_sig_05}/{n_tests} significant at p<0.05")
print(f"  {n_sig_10}/{n_tests} significant at p<0.10")


# ── Save results ──────────────────────────────────────────────────────
# Create the output directory if needed, so a fresh checkout does not fail
# at the very end of a long permutation run.
TABLE_DIR.mkdir(parents=True, exist_ok=True)

rows = [{'statistic': name, 'value': val, 'p_value': p}
        for name, val, p in results]
results_df = pd.DataFrame(rows)
results_df.to_csv(TABLE_DIR / "table10b_joint_permutation_test.csv", index=False)

# Markdown
md = ["# Table 10b: Joint Permutation Tests for Income Structure\n"]
md.append(f"CA/GDP. {N_PERMUTATIONS} random country orderings. Tests whether the "
          "income-ordered coefficient path is *jointly* different from random orderings, "
          "not just at individual steps.\n")
md.append("| Statistic | Income-ordered | p-value |")
md.append("|:---|---:|---:|")
for name, val, p in results:
    # Asterisks are escaped so Markdown renders them literally.
    sig = '\\*\\*\\*' if p < 0.001 else '\\*\\*' if p < 0.01 else '\\*' if p < 0.05 \
        else '†' if p < 0.10 else ''
    md.append(f"| {name} | {val:.3f} | {p:.4f}{sig} |")
md.append(f"\n**Interpretation**: Under H₀, income ordering is no different from random. "
          f"The mean Z₁ across the path, the fraction of steps above the random median, "
          f"and the Fisher combined rank test all assess whether the income-ordered path "
          f"is *consistently* elevated — not just at one step.")
md.append(f"\n**Per-step percentile ranks**: {', '.join(f'{r:.0%}' for r in per_step_ranks)}")
md.append(f"\n**Mean rank**: {np.mean(per_step_ranks):.1%} (50% under H₀)")

md_text = "\n".join(md)
# The table contains non-ASCII text (Z₁, H₀, †, —); write UTF-8 explicitly
# instead of relying on the platform's locale-dependent default encoding.
(TABLE_DIR / "table10b_joint_permutation_test.md").write_text(md_text, encoding="utf-8")
print(f"\nSaved table10b_joint_permutation_test")


# ── Figure: Joint test visualization ──────────────────────────────────
# Three panels, each a histogram of one joint statistic over the random
# orderings with the income-ordered value marked in red.
try:
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: render straight to file
    import matplotlib.pyplot as plt
except ImportError:
    print("  matplotlib not available — skipping figure")
else:
    # Only the imports are guarded, so an ImportError raised while plotting
    # is not misreported as "matplotlib not available".
    FIG_DIR.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # Panel A: Distribution of path means
    ax = axes[0]
    ax.hist(perm_means, bins=40, color='gray', alpha=0.6, edgecolor='white',
            label='Random orderings')
    ax.axvline(income_mean, color='red', linewidth=2.5,
               label=f'Income-ordered ({income_mean:.1f})')
    ax.axvline(np.percentile(perm_means, 95), color='gray', linewidth=1,
               linestyle='--', label='95th percentile')
    ax.set_xlabel('Mean Z₁ across path', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'A. Path Mean (p={p_mean:.3f})', fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)

    # Panel B: Distribution of fraction above median
    ax = axes[1]
    ax.hist(perm_above, bins=20, color='gray', alpha=0.6, edgecolor='white',
            label='Random orderings')
    ax.axvline(income_above, color='red', linewidth=2.5,
               label=f'Income-ordered ({income_above:.0%})')
    ax.set_xlabel('Fraction of steps above random median', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'B. Consistency (p={p_above:.3f})', fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)

    # Panel C: Distribution of slopes
    ax = axes[2]
    ax.hist(perm_slopes, bins=40, color='gray', alpha=0.6, edgecolor='white',
            label='Random orderings')
    ax.axvline(income_slope, color='red', linewidth=2.5,
               label=f'Income-ordered ({income_slope:.3f})')
    ax.set_xlabel('Slope of Z₁ path', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title(f'C. Declining slope (p={p_slope:.3f})', fontsize=12, fontweight='bold')
    ax.legend(fontsize=9)

    plt.suptitle('Joint Permutation Tests: Income-Ordered vs Random Paths',
                 fontsize=14, y=1.02)
    plt.tight_layout()
    fig.savefig(FIG_DIR / "figure5_joint_permutation_tests.png",
                dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("  Saved figure5_joint_permutation_tests.png")

print("\nPhase 6b complete.")
